{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GMM EM算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "from utils.metric import calculate_covariance_matrix,euclidean_distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class GaussianMixtureModel():\n",
    "    def __init__(self,k=3,max_iterations=400,tolerance=1e-8):\n",
    "        self.k = k\n",
    "        self.parameters = []\n",
    "        self.max_iterations = max_iterations\n",
    "        self.tolerance = tolerance\n",
    "        self.responsibilities = []\n",
    "        self.sample_assignments = None\n",
    "        self.responsibility = None\n",
    "    def _init_random_gaussians(self, X):\n",
    "        n_samples = np.shape(X)[0]\n",
    "        self.priors = (1./self.k) * np.ones(self.k)\n",
    "        for i in range(self.k):\n",
    "            params = {}\n",
    "            params['mean'] = X[np.random.choice(range(n_samples))]\n",
    "            params[\"cov\"] = calculate_covariance_matrix(X)\n",
    "            self.parameters.append(params)\n",
    "            \n",
    "    # 计算多元高斯分布\n",
    "    def multivariate_gaussian(self, X, params):\n",
    "        n_features = np.shape(X)[1]\n",
    "        mean = params[\"mean\"]\n",
    "        covar = params[\"cov\"]\n",
    "        determinant = np.linalg.det(covar)\n",
    "        likelihoods = np.zeros(np.shape(X)[0])\n",
    "        for i, sample in enumerate(X):\n",
    "            d = n_features\n",
    "            coeff = (1./(math.pow(2*math.pi,d/2)*math.sqrt(determinant)))\n",
    "            exponent = math.exp(-0.5 * (sample - mean).T.dot(np.linalg.pinv(covar)).dot((sample - mean)))\n",
    "            likelihoods[i] = coeff * exponent\n",
    "        \n",
    "        return likelihoods\n",
    "    def _get_likelihoods(self, X):\n",
    "        n_samples = np.shape(X)[0]\n",
    "        likelihoods = np.zeros((n_samples, self.k))\n",
    "        for i in range(self.k):\n",
    "            likelihoods[:,i] = self.multivariate_gaussian(X,self.parameters[i])\n",
    "        return likelihoods\n",
    "    # E步\n",
    "    def _expectation(self, X):\n",
    "        weighted_likelihoods = self._get_likelihoods(X) * self.priors \n",
    "        sum_likelihoods = np.expand_dims(np.sum(weighted_likelihoods, axis=1), axis=1)\n",
    "        self.responsibility = weighted_likelihoods / sum_likelihoods\n",
    "        self.sample_assignments = self.responsibility.argmax(axis=1)\n",
    "        self.responsibilities.append(np.max(self.responsibility, axis=1))\n",
    "    # M步\n",
    "    def _maximization(self, X):\n",
    "        for i in range(self.k):\n",
    "            resp = np.expand_dims(self.responsibility[:, i], axis=1)\n",
    "            mean = (resp * X).sum(axis=0) / resp.sum()\n",
    "            covariance = (X - mean).T.dot((X - mean) * resp) / resp.sum()\n",
    "            self.parameters[i][\"mean\"], self.parameters[i][\"cov\"] = mean, covariance\n",
    "        n_samples = np.shape(X)[0]\n",
    "        self.priors = self.responsibility.sum(axis=0) / n_samples\n",
    "            \n",
    "    def _converged(self, X):\n",
    "        if len(self.responsibilities) < 2:\n",
    "            return False\n",
    "        diff = np.linalg.norm(\n",
    "            self.responsibilities[-1] - self.responsibilities[-2])\n",
    "        return diff <= self.tolerance\n",
    "    def predict(self,X):\n",
    "        self._init_random_gaussians(X)\n",
    "        for _ in range(self.max_iterations):\n",
    "            self._expectation(X)\n",
    "            self._maximization(X)\n",
    "            if self._converged(X):\n",
    "                break\n",
    "        self._expectation(X)\n",
    "        return self.sample_assignments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.array([[ 0.697  ,0.46 ],\n",
    "                         [ 0.774  ,0.376],\n",
    "                         [ 0.634  ,0.264],\n",
    "                         [ 0.608  ,0.318],\n",
    "                         [ 0.556  ,0.215],\n",
    "                         [ 0.403  ,0.237],\n",
    "                         [ 0.481  ,0.149],\n",
    "                         [ 0.437  ,0.211],\n",
    "                         [ 0.666  ,0.091],\n",
    "                         [ 0.243  ,0.267],\n",
    "                         [ 0.245  ,0.057],\n",
    "                         [ 0.343  ,0.099],\n",
    "                         [ 0.639  ,0.161],\n",
    "                         [ 0.657  ,0.198],\n",
    "                         [ 0.36   ,0.37 ],\n",
    "                         [ 0.593  ,0.042],\n",
    "                         [ 0.719  ,0.103],\n",
    "                         [ 0.359  ,0.188],\n",
    "                         [ 0.339  ,0.241],\n",
    "                         [ 0.282  ,0.257],\n",
    "                         [ 0.748  ,0.232],\n",
    "                         [ 0.714  ,0.346],\n",
    "                         [ 0.483  ,0.312],\n",
    "                         [ 0.478  ,0.437],\n",
    "                         [ 0.525  ,0.369],\n",
    "                         [ 0.751  ,0.489],\n",
    "                         [ 0.532  ,0.472],\n",
    "                         [ 0.473  ,0.376],\n",
    "                         [ 0.725  ,0.445],\n",
    "                         [ 0.446  ,0.459]])\n",
    "clf = GaussianMixtureModel(k=3,max_iterations=10000,tolerance=1e-15)\n",
    "res = clf.predict(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "k0 = X[res==0]\n",
    "k1 = X[res==1]\n",
    "k2 = X[res==2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x117466eb8>"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD8CAYAAACSCdTiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFClJREFUeJzt3W+MVNd9xvHn6QYXQhtoy6ZrA6mx4sQl7iahEyJZlZoK\nBeNShN0gglM1alLJopLjTaVacSoVrZwXaZRKzhI7sZBlJa+KVo6D7ICLJaI2bd1KDNhZGydUlDRh\nN4uyjmuirJYa8K8vZgZml/1zZ3dm7twz34+EdufMZeZ39y4Pd8499xxHhAAAafqVvAsAALQOIQ8A\nCSPkASBhhDwAJIyQB4CEEfIAkDBCHgASRsgDQMIIeQBI2NvyeuM1a9bEzTffnNfbA0AhnThx4rWI\n6M26fW4hf/PNN6tcLuf19gBQSLZ/3Mj2dNcAQMIIeQBIGCEPAAkj5AEgYYQ8ACSMkAeAhGUKedvb\nbJ+2fcb2Q7M8/xHbF2y/VP2zr/mlAgAateA4eds9kh6T9FFJo5KO234mIl6dsem/RsSftKBGAMAi\nZTmT3yzpTEScjYg3JR2UtLO1ZQEAmiFLyK+VdK7u8Wi1baY7bI/Yfs72+2Z7Idv32S7bLk9MTCyi\nXABAI5p14fWkpHdFRL+kr0o6NNtGEXEgIkoRUertzTz1AgBgkbLMXTMmaX3d43XVtqsi4hd13x+x\n/TXbayLiteaUCQCd4dCLY/ry0dP66RtTumn1Cj1453t19wdn69zoDFnO5I9LutX2Bts3SNoj6Zn6\nDWz32Xb1+83V1/15s4sFgDwdenFMn3/6ZY29MaWQNPbGlD7/9Ms69OLYgn83LwuGfERclnS/pKOS\nfiBpOCJO2d5re291s12SXrH9fUn7Je2JiGhV0QCQhy8fPa2pS1emtU1duqIvHz2dU0ULyzTVcEQc\nkXRkRtvjdd8/KunR5pYGAJ3lp29MNdTeCbjjFQAyumn1iobaOwEhDwAZPXjne7ViWc+0thXLevTg\nne/NqaKF5bYyFAAUTW0UTZFG1xDywBIVbUgdlubuD64t1PEl5IElqA2pq424qA2pk1SoIEC66JMH\nlqCIQ+rQXQh5YAmKOKQO3YWQB5agiEPq0F0IeWAJijikDt2FC6/IVdFHphRxSF0Kiv57006EPHKT\nysiUog2pK7pUfm/ahe4a5IaRKVgMfm8aQ8gjN4xMwWLwe9MYQh65YWQKFoPfm8YQ8sgNI1OwGPze\nNIYLr8gNI1OwGM34vemm0TnOawGnUqkU5XI5l/cG0L1mjs6RKp8Evvinv1eIoLd9IiJKWbenuwZA\nV+m20Tl01yAJ3fTxG0vTbaNzOJNH4dU+fo+9MaXQtZtjDr04lndp6EDdNjqHkEfhddvHbyxNt43O\nobsGhddtH7+xNN02qouQR+HdtHqFxmYJ9FQ/fmPpumm+IbprUHjd9vEbaARn8ii8bvv4DTSCkEcS\nuunjN9AIumsAIGGEPAAkjJAHgIQR8gCQMEIeABJGyANAwgh5AEgYIQ8ACcsU8ra32T5t+4zth+bZ\n7kO2L9ve1bwSAQCLtWDI2+6R9JikuyRtlHSv7Y1zbPclSc83u0gAwOJkOZPfLOlMRJyNiDclHZS0\nc5btPiPpW5J+1sT6AABLkCXk10o6V/d4tNp2le21ku6R9PX5Xsj2fbbLtssTExON1goAaFCzLrx+\nRdLnIuKt+TaKiAMRUYqIUm9vb5PeGgAwlyyzUI5JWl/3eF21rV5J0kHbkrRG0h/bvhwRh5pSJQBg\nUbKE/HFJt9reoEq475H0ifoNImJD7Xvb35D0HQIeAPK3YMhHxGXb90s6KqlH0pMRccr23urzj7e4\nRgDAImVaNCQijkg6MqNt1nCPiL9YelkAgGbgjlcASBghj0I7fPawtj61Vf3f7NfWp7bq8NnDeZcE\nXG9kWHrkdmlwdeXryHDb3po1XlFYh88e1uALg7p45aIkaXxyXIMvDEqStt+yPcfKgDojw9KzD0iX\npiqPL5yrPJak/t0tf3vO5FFYQyeHrgZ8zcUrFzV0ciinioBZHHv4WsDXXJqqtLcBIV9w3dxdcX7y\nfEPtzdLNP3MswoXRxtqbjJAvsFp3xfjkuEJxtbuiW0Knb2VfQ+3N0O0/cyzCqnWNtTcZIV9g3d5d\nMbBpQMt7lk9rW96zXAObBlr2nt3+M8cibNknLVsxvW3Zikp7G3DhtcDy6q7oFLWLq0Mnh3R+8rz6\nVvZpYNNASy+6dvvPHItQu7h67OFKF82qdZWAb8NFV4mQL7S+lX0anxyftb1bbL9le1tH0vAzx6L0\n725bqM9Ed02B5dFd0e34maNoOJMvsDy6K7odP3MUjSMilzculUpRLpdzeW8AKCrbJyKilHV7umsA\nIGGEPAAkjJAHkK4cJwbrFFx4BZCmnCcG6xScyQNIU84Tg3UKQh5AmnKeGKxTEPIA0pTzxGCdgpAH\nkKacJwbrFIQ8gDT175Z27JdWrZfkytcd+7vqoqvE6BoAKctxYrBOwZk8ACSMkAeAhBHyAJAwQh4A\nEkbIA13k8NnD2vrUVvV/s19bn9rKAuRdgNE1QJc4fPawBl8YvLoQ+fjkuAZfGJQkFj1JGGfyQJcY\nOjl0NeBrLl65qKGTQzlVhHYg5JEcuiRmd37yfEPtSAMhj6TUuiTGJ8cViqtdEgS91Leyr6F2pIGQ\nR1LokpjbwKYBLe9ZPq1tec9yDWwayKkitAMXXpEUuiTmVru4OnRySOcnz6tvZZ8GNg1w0TVxmULe\n9jZJQ5J6JD0REX8/4/mdkr4g6S1JlyV9NiL+rcm1AgvqW9mn8cnxWdtRCXpCvbss2F1ju0fSY5Lu\nkrRR0r22N87Y7Jik90fEByR9WtITzS4UyIIuCWC6LGfymyWdiYizkmT7oKSdkl6tbRARv6zbfqWk\naGaRQFZ0SQDTZQn5tZLO1T0elfThmRvZvkfSFyW9UxL/opAbuiSAa5o2uiYivh0Rt0m6W5X++evY\nvs922XZ5YmKiWW8NAJhDlpAfk7S+7vG6atusIuJ7km6xvWaW5w5ERCkiSr29vQ0XCwBoTJaQPy7p\nVtsbbN8gaY+kZ+o3sP1u265+v0nSr0r6ebOLBQA0ZsE++Yi4bPt+SUdVGUL5ZEScsr23+vzjkj4m\n6ZO2L0makvTxiODiKwDkzHllcalUinK5nMt7A0BR2T4REaWs2zOtAQAkjJAHgIQR8gCQMEIeABJG\nyANAwgh5AEhYIUOe5d0AIJvCLRrCivNATkaGpWMPSxdGpVXrpC37pP7deVeFBRTuTJ7l3YAcjAxL\nzz4gXTgnKSpfn32g0o6OVriQZ3k3IAfHHpYuTU1vuzRVaUdHK1zIt3LFefr6gTlcGG2sHR2jcCHf\nquXdan3945PjCsXVvn6CHlClD76RdnSMwoX89lu2a/COQd248kZZ1o0rb9TgHYNLvuhKXz8wjy37\npGUrprctW1FpR0cr3OgaqTXLu9HXD8yjNoqG0TWFU8iQb4W+lX0anxyftR2AKoFOqBdO4bprWqVV\nff0AkCfO5Ktq3T9DJ4d0fvK8+lb2aWDTADdYASg0Qr5OK/r6G3X47GH+owHQNIR8B2HKBgDNRp98\nB+nEYZzcIAYUG2fyHaTThnHyyQIoPs7kO0grp2xYjE78ZAHkZmRYeuR2aXB15WtBJmcj5DtIpw3j\n7LRPFkBuCjwLJyHfQVo1ZcNiddonCyA3BZ6Fkz75DtMJwzhrBjYNTOuTl7hBDF2qwLNwEvKYEzeI\nAVWr1lW7amZp73CEPObVSZ8sgNxs2Vfpg6/vsinILJz0yQPAQvp3Szv2S6vWS3Ll6479hZiwjTN5\nICdMYVEwBZ2Fk5AHcsCNZmgXumvQNkyRcA03mqFdOJNHW3DmOh03mqFdOJNHW3DmOh03mqFdCHm0\nBWeu03XaFBZIV6aQt73N9mnbZ2w/NMvzf2Z7xPbLtl+w/f7ml4oi48x1uk6bwgLpWrBP3naPpMck\nfVTSqKTjtp+JiFfrNvuRpD+MiP+1fZekA5I+3IqCUUxMkXA9bjRDO2S58LpZ0pmIOCtJtg9K2inp\nashHxAt12/+npM6/1xdtxRQJXWpkuDKJ14XRyhQAW/YVcqx5kWUJ+bWS6idtGNX8Z+l/Kem5pRSF\nNHHm2mVq0/PWpgKoTc8rEfRt1NQLr7b/SJWQ/9wcz99nu2y7PDEx0cy3BtBpCjw9b0qyhPyYpPV1\nj9dV26ax3S/pCUk7I+Lns71QRByIiFJElHp7exdTL4CiKPD0vCnJEvLHJd1qe4PtGyTtkfRM/Qa2\n3yXpaUl/HhH/1fwyMRvuIEVHm2sa3gJMz5uSBUM+Ii5Lul/SUUk/kDQcEads77W9t7rZPkm/Jelr\ntl+yXW5ZxZB07Q7S8clxheLqHaQEPTrGln2V6XjrFWR63pQ4InJ541KpFOUy/xcs1tantmp8cvy6\n9htX3qjndz2fQ0XALBhd03S2T0REKev2zF1TUNxBikIo6PS8KWFag4LiDlIAWRDyBcXcJwCyoLum\noLiDFEAWhHyBcQcpgIXQXQMACSPkASBhhDwAJIyQB9C4kWHpkdulwdWVryPDeVeEOXDhFUBjmEK4\nUDiTB9AYphAuFEIeQGOYQrhQCHkAjWEK4UIh5AE0himEC4WQR1JYSKUN+ndLO/ZLq9ZLcuXrjv1c\ndO1QjK5BMmoLqVy8clGSri6kIonpH5qNKYQLgzN5JGPo5NDVgK+5eOWihk4O5VQRkD9CHslgIRXg\neoQ8ksFCKsD1CHkkg4VUgOtx4RXJYCEV4HqEPJLCQirAdHTXAEDCCHkASBghDwAJI+QBIGGEPAAk\njJAHgIQR8gCQMEIeABJGyANAwgh5AEgYIQ8ACSPkAXSHkWHpkdulwdWVryPDeVfUFplC3vY226dt\nn7H90CzP32b7P2z/n+2/aX6ZALAEI8PSsw9IF85JisrXZx/oiqBfMORt90h6TNJdkjZKutf2xhmb\nvS7pAUn/0PQKuxQLUgNNdOxh6dLU9LZLU5X2xGU5k98s6UxEnI2INyUdlLSzfoOI+FlEHJd0qQU1\ndp3agtTjk+MKxdUFqQl6YJEujDbWnpAsIb9W0rm6x6PVNrQIC1IDTbZqXWPtCWnrhVfb99ku2y5P\nTEy0860LhQWpkbQ8LoBu2SctWzG9bdmKSnvisoT8mKT1dY/XVdsaFhEHIqIUEaXe3t7FvERXYEFq\nJCuvC6D9u6Ud+6VV6yW58nXH/kp74rKE/HFJt9reYPsGSXskPdPasrobC1IjWXleAO3fLf31K9Lg\nG5WvXRDwUoY1XiPisu37JR2V1CPpyYg4ZXtv9fnHbfdJKkt6h6S3bH9W0saI+EULa08WC1IjWV18\nATQvmRbyjogjko7MaHu87vvzqnTjoElYkBpJWrWu2lUzSztagjteAbRPF18AzQshD6B9uvgCaF4y\nddcAQNP07ybU24gzeQBIGCEPAAkj5AEgYYQ8ACSMkAeAhBHyAJAwQh4AElbYkGflJABYWCFvhqqt\nnFRbWKO2cpIk5nsBgDqFPJNn5SQAyKaQIc/KSQCQTSFDnpWTACCbQoY8KycBQDaFvPDKyknALEaG\nK8voXRitLMKxZR+zPaKYIS+xchIwTW2B7Nr6qbUFsiWCvssVsrsGwAx5LpCNjkbIAylggWzMgZAH\nUjDXQtgskN31CHkgBSyQjTkQ8kAKWCAbcyjs6BoAM7BANmbBmTwAJIyQB4CEEfIAkDBCHgASRsgD\nQMIIeQBImCMinze2JyT9OJc3b641kl7Lu4gWSn3/pPT3MfX9k7prH38nInqz/qXcQj4VtssRUcq7\njlZJff+k9Pcx9f2T2Mf50F0DAAkj5AEgYYT80h3Iu4AWS33/pPT3MfX9k9jHOdEnDwAJ40weABJG\nyGdge5vt07bP2H5olud32h6x/ZLtsu0/yKPOpVhoH+u2+5Dty7Z3tbO+pcpwDD9i+0L1GL5ku3AT\nsWc5htX9fMn2Kdv/0u4alyrDcXyw7hi+YvuK7d/Mo9bFyLB/q2w/a/v71WP4qQVfNCL4M88fST2S\n/lvSLZJukPR9SRtnbPNrutb11S/ph3nX3ex9rNvuu5KOSNqVd91NPoYfkfSdvGtt8T6ulvSqpHdV\nH78z77qbvY8ztt8h6bt5193kY/i3kr5U/b5X0uuSbpjvdTmTX9hmSWci4mxEvCnpoKSd9RtExC+j\n+lOXtFJS0S50LLiPVZ+R9C1JP2tncU2Qdf+KLMs+fkLS0xHxE0mKiNSP472S/rEtlTVHlv0LSb9u\n26qcXL4u6fJ8L0rIL2ytpHN1j0erbdPYvsf2DyUdlvTpNtXWLAvuo+21ku6R9PU21tUsmY6hpDuq\n3W7P2X5fe0prmiz7+B5Jv2H7n22fsP3JtlXXHFmPo2y/XdI2VU5KiiLL/j0q6Xcl/VTSy5IGIuKt\n+V6UkG+SiPh2RNwm6W5JX8i7nhb4iqTPLfQLVWAnVenG6Jf0VUmHcq6nFd4m6fclbZd0p6S/s/2e\nfEtqmR2S/j0iXs+7kCa7U9JLkm6S9AFJj9p+x3x/gZBf2Jik9XWP11XbZhUR35N0i+01rS6sibLs\nY0nSQdv/I2mXpK/Zvrs95S3ZgvsXEb+IiF9Wvz8iaVmCx3BU0tGImIyI1yR9T9L721RfMzTyb3GP\nitVVI2Xbv0+p0uUWEXFG0o8k3Tbvq+Z9saHT/6hy9nNW0gZduxjyvhnbvFvXLrxuqh4Y5117M/dx\nxvbfULEuvGY5hn11x3CzpJ+kdgxV+Zh/rLrt2yW9Iun2vGtv5j5Wt1ulSl/1yrxrbsEx/Lqkwer3\nv13NmjXzvS4LeS8gIi7bvl/SUVWufj8ZEads760+/7ikj0n6pO1LkqYkfTyqR6EIMu5jYWXcv12S\n/sr2ZVWO4Z7UjmFE/MD2P0kakfSWpCci4pX8qm5MA7+n90h6PiImcyp1UTLu3xckfcP2y5KsShfq\nvLNvcscrACSMPnkASBghDwAJI+QBIGGEPAAkjJAHgIQR8gCQMEIeABJGyANAwv4fbpok1PJKebMA\nAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x1171fab00>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(k0[:,0],k0[:,1])\n",
    "plt.scatter(k1[:,0],k1[:,1])\n",
    "plt.scatter(k2[:,0],k2[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
